iT邦幫忙

2024 iThome 鐵人賽

DAY 25
0
Software Development

LSTM結合Yolo v8對於多隻斑馬魚行為分析系列 第 25

day 25 lstm結合yolo分析斑馬魚行為模型系統

  • 分享至 

  • xImage
  •  

今天是第二十五天我們可以寫一個lstm結合yolo去分析斑馬的模型系統,我們可以先看我們的模型準不準,以下是程式碼

1. YOLOv8 模型構建與檢測

首先,YOLOv8需要用於檢測斑馬魚的位置。你可以使用Ultralytics YOLOv8框架進行模型的加載和檢測。

from ultralytics import YOLO
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import os

# YOLOv8 模型初始化
yolo_model = YOLO('yolov8n.pt')  # 可以選擇不同的YOLOv8模型,根據需求選擇n, s, m, l, x等模型

# 檢測斑馬魚的位置
def detect_fish(frame):
    results = yolo_model(frame)  # 進行目標檢測
    detections = []
    for result in results.xyxy[0].cpu().numpy():  # 提取檢測結果
        x_min, y_min, x_max, y_max, confidence, class_id = result
        if class_id == 0:  # 假設斑馬魚的class_id為0
            x_center = (x_min + x_max) / 2
            y_center = (y_min + y_max) / 2
            detections.append((x_center, y_center))
    return detections

2. 數據集準備與處理

YOLOv8模型檢測到斑馬魚的位置後,我們需要將其轉換為LSTM模型的輸入格式。這部分代碼將視頻幀或圖像序列轉換為LSTM所需的數據。

def prepare_data(frames, lookback=9):
    X, y = [], []
    for i in range(len(frames) - lookback):
        input_seq = []
        for j in range(lookback):
            detected_fish = detect_fish(frames[i + j])
            input_seq.append(detected_fish)
        X.append(input_seq)
        y.append(detect_fish(frames[i + lookback]))  # 下一幀的真實位置
    return np.array(X), np.array(y)

def load_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    return frames

3. LSTM 模型構建與訓練

接下來,我們會構建一個更複雜的LSTM模型,並使用檢測到的數據進行訓練。模型包括多層LSTM、Dropout、BatchNormalization等層。

def create_complex_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(256, input_shape=input_shape, return_sequences=True))
    model.add(Dropout(0.3))
    model.add(BatchNormalization())

    model.add(LSTM(128, return_sequences=True))
    model.add(Dropout(0.3))
    model.add(BatchNormalization())

    model.add(LSTM(64))
    model.add(Dropout(0.3))

    model.add(Dense(32, activation='relu'))
    model.add(Dense(2))  # 輸出斑馬魚的未來位置(x, y)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse')
    return model

4. 模型訓練與選擇

最後,將數據集拆分為訓練集和測試集,訓練LSTM模型並選擇最佳模型。

# 載入視頻幀
video_path = 'zebrafish_video.mp4'
frames = load_video_frames(video_path)

# 準備數據
lookback = 9
X, y = prepare_data(frames, lookback)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# 模型訓練
input_shape = (X_train.shape[1], X_train.shape[2], 2)  # lookback, number of fish, (x, y)
model = create_complex_lstm_model(input_shape)
history = model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_val, y_val))

# 選擇最佳模型
val_loss = history.history['val_loss']
best_epoch = np.argmin(val_loss)
best_model = model
best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5')

print(f"最佳模型儲存在第{best_epoch+1}輪訓練後,驗證集損失為{val_loss[best_epoch]:.4f}")

1. YOLOv8 模型構建與檢測

from ultralytics import YOLO
import cv2
import numpy as np

這一部分導入了所需的庫:

  • ultralytics.YOLO 是 YOLOv8 的一個實現庫,它提供了簡便的接口來加載和使用 YOLOv8 模型。
  • cv2 是 OpenCV 庫,用於處理圖像和視頻。cv2.VideoCapture 可以用來讀取視頻幀。
  • numpy 是一個常用的數據處理庫,用於進行數學運算和處理多維數組。
# YOLOv8 模型初始化
yolo_model = YOLO('yolov8n.pt')

模型初始化:

  • yolo_model = YOLO('yolov8n.pt') 加載了 YOLOv8 的一個預訓練模型,這裡使用的是 yolov8n.pt,即 YOLOv8 的 nano 版本,這是一個相對輕量級的模型,適合資源有限的設備進行推理。
# 檢測斑馬魚的位置
def detect_fish(frame):
    results = yolo_model(frame)
    detections = []
    for result in results.xyxy[0].cpu().numpy():
        x_min, y_min, x_max, y_max, confidence, class_id = result
        if class_id == 0:  # 假設斑馬魚的class_id為0
            x_center = (x_min + x_max) / 2
            y_center = (y_min + y_max) / 2
            detections.append((x_center, y_center))
    return detections

目標檢測與位置計算:

  • results = yolo_model(frame) 這行代碼對單幀圖像進行目標檢測,YOLOv8 返回檢測結果,結果包括邊界框座標、置信度和類別ID。
  • for result in results.xyxy[0].cpu().numpy() 這行提取 YOLO 檢測的每一個目標,results.xyxy[0] 代表檢測出的邊界框數據。
  • (x_min, y_min, x_max, y_max) 是檢測出的邊界框的左上角和右下角座標。
  • class_id 表示該目標的類別標識符,if class_id == 0 假設斑馬魚的類別標識符為0。
  • (x_center, y_center) 是斑馬魚的中心位置,用來描述斑馬魚的座標。
  • 最終 detections 返回這些中心點的座標列表。

2. 數據集準備與處理

def prepare_data(frames, lookback=9):
    X, y = []
    for i in range(len(frames) - lookback):
        input_seq = []
        for j in range(lookback):
            detected_fish = detect_fish(frames[i + j])
            input_seq.append(detected_fish)
        X.append(input_seq)
        y.append(detect_fish(frames[i + lookback]))  # 下一幀的真實位置
    return np.array(X), np.array(y)

準備序列數據:

  • prepare_data 函數將一系列幀圖像轉換為 LSTM 模型的訓練數據。
  • lookback=9 指定 LSTM 模型回顧的時間步長,也就是模型在預測時需要參考前多少幀的數據。
  • for i in range(len(frames) - lookback) 遍歷每一個可能的序列,lookback 决定了能構成多少個序列。
  • input_seq.append(detected_fish) 用 YOLOv8 檢測每一幀的斑馬魚位置,並將其添加到 input_seq 中。
  • y.append(detect_fish(frames[i + lookback])) 則存儲下一幀的真實位置作為目標輸出,用來訓練 LSTM。
def load_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    return frames

讀取視頻幀:

  • cap = cv2.VideoCapture(video_path) 使用 OpenCV 讀取指定路徑的視頻文件。
  • while cap.isOpened() 確保視頻流被成功打開,cap.read() 讀取每一幀圖像。
  • frames.append(frame) 將每一幀圖像存儲在列表 frames 中。
  • 最終返回所有幀圖像的列表,用於後續處理。

3. LSTM 模型構建與訓練

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

導入LSTM相關庫:

  • Sequential 是 Keras 中用來構建模型的順序容器。
  • LSTM 是長短期記憶神經網絡層,用於處理序列數據。
  • Dense 是全連接層,用於將LSTM層的輸出映射到目標輸出。
  • Dropout 是一種正則化技術,用於防止過擬合。
  • BatchNormalization 是一種技術,用於加快訓練速度並穩定學習過程。
  • Adam 是一種常用的優化器,用於調整神經網絡的權重。
  • train_test_split 用於將數據集拆分為訓練集和驗證集。
def create_complex_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(256, input_shape=input_shape, return_sequences=True))
    model.add(Dropout(0.3))
    model.add(BatchNormalization())

    model.add(LSTM(128, return_sequences=True))
    model.add(Dropout(0.3))
    model.add(BatchNormalization())

    model.add(LSTM(64))
    model.add(Dropout(0.3))

    model.add(Dense(32, activation='relu'))
    model.add(Dense(2))  # 輸出斑馬魚的未來位置(x, y)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse')
    return model

構建LSTM模型:

  • Sequential 容器: 定義一個順序模型,其中的層按順序排列。
  • LSTM(256, return_sequences=True): 第一層LSTM有256個單元,input_shape 指定了輸入的形狀,return_sequences=True 意味著該層輸出的每個時間步驟都將作為下一層的輸入。
  • Dropout(0.3): 丟棄30%的神經元來防止過擬合。
  • BatchNormalization: 標準化輸出,以穩定模型訓練。
  • 第二層LSTM(128): 第二層LSTM有128個單元,依然返回序列輸出。
  • 第三層LSTM(64): 第三層LSTM有64個單元,這層沒有 return_sequences=True,因此這層的輸出將成為後續層的單一輸入。
  • Dense(32, activation='relu'): 全連接層,將 LSTM 的輸出映射到32個神經元,並應用 ReLU 激活函數來引入非線性。
  • Dense(2): 最後的全連接層,輸出2個值,對應斑馬魚未來位置的 xy 坐標。
  • 模型編譯: 使用 Adam 優化器和均方誤差(MSE)損失函數來編譯模型,學習率設定為 0.0005。

4. 模型訓練與選擇

# 載入視頻幀
video_path = 'zebrafish_video.mp4'
frames = load_video_frames(video_path)

載入視頻幀:

  • 這部分代碼使用先前定義的 load_video_frames 函數載入給定視頻

的所有幀,並將它們存儲在 frames 列表中。

# 準備數據
lookback = 9
X, y = prepare_data(frames, lookback)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

數據準備:

  • lookback = 9 定義了 LSTM 模型的回顧步長。
  • X, y = prepare_data(frames, lookback) 使用先前的 prepare_data 函數來準備數據,X 是輸入序列,y 是對應的目標輸出。
  • train_test_split 將數據集拆分為訓練集和驗證集,驗證集佔20%。
# 模型訓練
input_shape = (X_train.shape[1], X_train.shape[2], 2)  # lookback, number of fish, (x, y)
model = create_complex_lstm_model(input_shape)
history = model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_val, y_val))

模型訓練:

  • input_shape = (X_train.shape[1], X_train.shape[2], 2) 確定LSTM模型的輸入形狀,其中 X_train.shape[1] 是回顧步長,X_train.shape[2] 是斑馬魚的數量,每個斑馬魚有兩個值 xy
  • model.fit 函數用來訓練模型,epochs=200 表示模型會在整個數據集上訓練200次,batch_size=32 表示每次訓練32個樣本,validation_data 用於在每個 epoch 結束後計算驗證集的損失。
# 選擇最佳模型
val_loss = history.history['val_loss']
best_epoch = np.argmin(val_loss)
best_model = model
best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5')

print(f"最佳模型儲存在第{best_epoch+1}輪訓練後,驗證集損失為{val_loss[best_epoch]:.4f}")

模型保存與選擇:

  • val_loss = history.history['val_loss'] 獲取訓練過程中每個 epoch 的驗證損失值。
  • best_epoch = np.argmin(val_loss) 找到驗證損失最小的 epoch,這個 epoch 對應最佳模型。
  • best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5') 將最佳模型保存為 .h5 文件,文件名包含最佳 epoch 的數字。
  • 最後,使用 print 函數輸出最佳模型所在的訓練輪次以及對應的驗證損失值。

總結

這段程式碼結合了YOLOv8的目標檢測能力和LSTM模型的時間序列預測能力,能夠對視頻中的斑馬魚進行行為分析。YOLOv8用於檢測斑馬魚的位置,而LSTM則學習這些位置的時間序列模式,從而預測未來的行為。這是一個相對複雜的深度學習應用,適用於動物行為研究中需要分析大量序列數據的場景。


上一篇
day 24 Lstm結合yolo 對於多隻斑馬魚軌跡圖分析
下一篇
day 26 lstm多隻斑馬魚模型分析
系列文
LSTM結合Yolo v8對於多隻斑馬魚行為分析29
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言